package com.example.autumn.web;

import com.example.autumn.annotation.*;
import com.example.autumn.context.ApplicationContext;
import com.example.autumn.context.BeanDefinition;
import com.example.autumn.context.ConfigurableApplicationContext;
import com.example.autumn.exeception.ErrorResponseException;
import com.example.autumn.exeception.NestedRuntimeException;
import com.example.autumn.exeception.ServerErrorException;
import com.example.autumn.exeception.ServerWebInputException;
import com.example.autumn.io.PropertyResolver;
import com.example.autumn.utils.ClassUtils;
import com.example.autumn.web.utils.JsonUtils;
import com.example.autumn.web.utils.PathUtils;
import com.example.autumn.web.utils.WebUtils;
import jakarta.servlet.ServletConfig;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Front-controller servlet that routes HTTP requests to controller handler methods.
 *
 * <p>During {@link #init()} every bean in the IoC container annotated with {@code @Controller} or
 * {@code @RestController} is scanned, and each {@code @GetMapping} / {@code @PostMapping} method is
 * registered as a {@link Dispatcher}. GET requests for the favicon or for paths under the static
 * resource prefix are served directly from the {@link ServletContext}.
 *
 * @author liuzhiyong
 * @date 2023/11/8
 */
public class DispatcherServlet extends HttpServlet {

    /**
     * Prefix marking a handler result (view name or String) as an HTTP redirect.
     */
    private static final String REDIRECT_PREFIX = "redirect:";

    final Logger logger = LoggerFactory.getLogger(getClass());

    /**
     * IoC container holding the controller beans.
     */
    ApplicationContext applicationContext;

    /**
     * View resolver used to render MVC (non-REST) results.
     */
    ViewResolver viewResolver;

    /**
     * URL prefix of static resources; normalized in the constructor to end with {@code '/'}.
     */
    String resourcePath;

    /**
     * URL of the favicon, served like a static resource.
     */
    String faviconPath;

    /**
     * Dispatchers for GET requests. Each describes one endpoint and can build the handler's
     * arguments and invoke it reflectively.
     */
    List<Dispatcher> getDispatchers = new ArrayList<>();

    /**
     * Dispatchers for POST requests. Each describes one endpoint and can build the handler's
     * arguments and invoke it reflectively.
     */
    List<Dispatcher> postDispatchers = new ArrayList<>();

    /**
     * Creates the servlet.
     *
     * <p>Reads {@code autumn.web.static-path} (default {@code /static/}) and
     * {@code autumn.web.favicon-path} (default {@code /favicon.ico}); the static path is forced to
     * end with {@code '/'}.
     *
     * @param applicationContext IoC container; must contain a {@link ViewResolver} bean
     * @param propertyResolver   resolver used to read the web configuration properties
     */
    public DispatcherServlet(ApplicationContext applicationContext,
                             PropertyResolver propertyResolver) {
        this.applicationContext = applicationContext;
        this.viewResolver = applicationContext.getBean(ViewResolver.class);
        this.resourcePath = propertyResolver.getProperty("${autumn.web.static-path:/static/}");
        this.faviconPath = propertyResolver.getProperty("${autumn.web.favicon-path:/favicon.ico}");
        if (!this.resourcePath.endsWith("/")) {
            this.resourcePath = this.resourcePath + "/";
        }
    }

    /**
     * Registers every {@code @Controller} / {@code @RestController} bean as a set of dispatchers.
     *
     * @throws ServletException if a class carries both controller annotations, a handler method is
     *                          static, or a handler parameter combines binding annotations
     */
    @Override
    public void init() throws ServletException {
        logger.info("init {}.", getClass().getName());
        // Walk every bean definition in the container and pick out the controllers.
        for (BeanDefinition def : ((ConfigurableApplicationContext) this.applicationContext).findBeanDefinitions(Object.class)) {
            Class<?> beanClass = def.getBeanClass();
            Object bean = def.getRequiredInstance();
            Controller controller = beanClass.getAnnotation(Controller.class);
            RestController restController = beanClass.getAnnotation(RestController.class);
            if (controller != null && restController != null) {
                // A class must not be both an MVC controller and a REST controller.
                throw new ServletException("Found @Controller and @RestController on class: " + beanClass.getName());
            }
            if (controller != null) {
                addController(false, def.getName(), bean);
            }
            if (restController != null) {
                addController(true, def.getName(), bean);
            }
        }
    }

    /**
     * Closes the IoC container when the servlet is taken out of service.
     */
    @Override
    public void destroy() {
        this.applicationContext.close();
    }

    /**
     * Registers one controller bean.
     *
     * @param isRest   {@code true} for a {@code @RestController}, {@code false} for a {@code @Controller}
     * @param name     bean name (used for logging)
     * @param instance bean instance the handler methods are invoked on
     * @throws ServletException if a handler method is invalid
     * @author liuzhiyong
     * @date 2023/11/9
     */
    void addController(boolean isRest, String name, Object instance) throws ServletException {
        logger.info("add {} controller '{}': {}", isRest ? "REST" : "MVC", name, instance.getClass().getName());
        addMethods(isRest, name, instance, instance.getClass());
    }

    /**
     * Registers every {@code @GetMapping} / {@code @PostMapping} method declared on {@code type} and
     * its superclasses. A method may carry both annotations and is then registered for both verbs.
     *
     * @param isRest   whether the owning bean is a REST controller
     * @param name     bean name
     * @param instance bean instance
     * @param type     class whose declared methods are scanned (recurses into the superclass)
     * @throws ServletException if a handler method is invalid
     * @author liuzhiyong
     * @date 2023/11/9
     */
    void addMethods(boolean isRest, String name, Object instance, Class<?> type) throws ServletException {
        for (Method m : type.getDeclaredMethods()) {
            // javac copies method annotations onto compiler-generated bridge methods; registering
            // them would map the same URL a second time, with erased parameter types.
            if (m.isBridge()) {
                continue;
            }
            GetMapping get = m.getAnnotation(GetMapping.class);
            if (get != null) {
                checkMethod(m);
                this.getDispatchers.add(new Dispatcher("GET", isRest, instance, m, get.value()));
            }
            PostMapping post = m.getAnnotation(PostMapping.class);
            if (post != null) {
                checkMethod(m);
                this.postDispatchers.add(new Dispatcher("POST", isRest, instance, m, post.value()));
            }
        }
        // Inherited handler methods are registered too.
        Class<?> superClass = type.getSuperclass();
        if (superClass != null) {
            addMethods(isRest, name, instance, superClass);
        }
    }

    /**
     * Validates a handler method and makes it reflectively accessible.
     *
     * @param m handler method
     * @throws ServletException if the method is static
     * @author liuzhiyong
     * @date 2023/11/9
     */
    void checkMethod(Method m) throws ServletException {
        int mod = m.getModifiers();
        if (Modifier.isStatic(mod)) {
            throw new ServletException("Cannot do URL mapping to static method: " + m);
        }
        m.setAccessible(true);
    }

    /**
     * Handles a GET request: the favicon and paths under the static prefix are served as resources,
     * everything else is dispatched to the GET handlers.
     *
     * @param req  request
     * @param resp response
     * @author liuzhiyong
     * @date 2023/11/9
     */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        String url = req.getRequestURI();
        if (url.equals(this.faviconPath) || url.startsWith(this.resourcePath)) {
            doResource(url, req, resp);
        } else {
            doService(req, resp, this.getDispatchers);
        }
    }

    /**
     * Handles a POST request by dispatching it to the POST handlers.
     *
     * @param req  request
     * @param resp response
     * @author liuzhiyong
     * @date 2023/11/9
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doService(req, resp, this.postDispatchers);
    }

    /**
     * Dispatches a request and translates failures.
     *
     * <p>An {@link ErrorResponseException} becomes an error response with its status code (if the
     * response is not yet committed); other exceptions are logged and rethrown, checked exceptions
     * other than {@link ServletException}/{@link IOException} wrapped in a {@link NestedRuntimeException}.
     */
    void doService(HttpServletRequest req, HttpServletResponse resp, List<Dispatcher> dispatchers) throws ServletException, IOException {
        String url = req.getRequestURI();
        try {
            doService(url, req, resp, dispatchers);
        } catch (ErrorResponseException e) {
            logger.warn("process request failed with status {} : {}", e.statusCode, url, e);
            // Only a response that has not been sent yet can be turned into an error page.
            if (!resp.isCommitted()) {
                resp.resetBuffer();
                resp.sendError(e.statusCode);
            }
        } catch (RuntimeException | ServletException | IOException e) {
            logger.warn("process request failed: {}", url, e);
            throw e;
        } catch (Exception e) {
            logger.warn("process request failed: {}", url, e);
            throw new NestedRuntimeException(e);
        }
    }

    /**
     * Tries each dispatcher in order; the first whose URL pattern matches handles the request and
     * its result is written to the response. Responds 404 when no dispatcher matches.
     */
    void doService(String url, HttpServletRequest req, HttpServletResponse resp, List<Dispatcher> dispatchers) throws Exception {
        for (Dispatcher dispatcher : dispatchers) {
            Result result = dispatcher.process(url, req, resp);
            if (result.processed()) {
                if (dispatcher.isRest) {
                    writeRestResult(url, dispatcher, result.returnObject(), resp);
                } else {
                    writeMvcResult(url, dispatcher, result.returnObject(), req, resp);
                }
                return;
            }
        }
        resp.sendError(404, "NOT Found");
    }

    /**
     * Writes the result of a REST handler. {@code @ResponseBody} handlers must return a
     * {@code String} or {@code byte[]}, which is written verbatim; other non-void handlers have their
     * return value serialized as JSON.
     */
    private void writeRestResult(String url, Dispatcher dispatcher, Object r, HttpServletResponse resp) throws Exception {
        if (!resp.isCommitted()) {
            resp.setContentType("application/json");
        }
        if (dispatcher.isResponseBody) {
            if (r instanceof String s) {
                writeText(resp, s);
            } else if (r instanceof byte[] data) {
                writeBytes(resp, data);
            } else {
                throw new ServletException("Unable to process REST result when handle url: " + url);
            }
        } else if (!dispatcher.isVoid) {
            PrintWriter pw = resp.getWriter();
            JsonUtils.writeJson(pw, r);
            pw.flush();
        }
    }

    /**
     * Writes the result of an MVC handler: a {@code @ResponseBody} String/byte[] verbatim, a
     * {@code "redirect:"} String or view name as a redirect, a {@link ModelAndView} through the view
     * resolver. {@code null} results and void handlers write nothing.
     */
    private void writeMvcResult(String url, Dispatcher dispatcher, Object r, HttpServletRequest req,
                                HttpServletResponse resp) throws Exception {
        if (!resp.isCommitted()) {
            resp.setContentType("text/html");
        }
        if (r instanceof String s) {
            if (dispatcher.isResponseBody) {
                writeText(resp, s);
            } else if (s.startsWith(REDIRECT_PREFIX)) {
                resp.sendRedirect(s.substring(REDIRECT_PREFIX.length()));
            } else {
                throw new ServerErrorException("Unable to process String result when handle url: " + url);
            }
        } else if (r instanceof byte[] data) {
            if (dispatcher.isResponseBody) {
                writeBytes(resp, data);
            } else {
                throw new ServletException("Unable to process byte[] result when handle url: " + url);
            }
        } else if (r instanceof ModelAndView mv) {
            String view = mv.getViewName();
            if (view.startsWith(REDIRECT_PREFIX)) {
                resp.sendRedirect(view.substring(REDIRECT_PREFIX.length()));
            } else {
                this.viewResolver.render(view, mv.getModel(), req, resp);
            }
        } else if (!dispatcher.isVoid && r != null) {
            // A non-void handler returned something that is not a supported MVC result type.
            throw new ServletException("Unable to process " + r.getClass().getName() + " result when handle url: " + url);
        }
    }

    /**
     * Writes {@code text} through the response writer and flushes it.
     */
    private static void writeText(HttpServletResponse resp, String text) throws IOException {
        PrintWriter pw = resp.getWriter();
        pw.write(text);
        pw.flush();
    }

    /**
     * Writes {@code data} to the response output stream and flushes it.
     */
    private static void writeBytes(HttpServletResponse resp, byte[] data) throws IOException {
        ServletOutputStream output = resp.getOutputStream();
        output.write(data);
        output.flush();
    }

    /**
     * Serves a static resource from the {@link ServletContext}, guessing the content type from the
     * file name ({@code application/octet-stream} when unknown). Responds 404 if it does not exist.
     */
    void doResource(String url, HttpServletRequest req, HttpServletResponse resp) throws IOException {
        ServletContext ctx = req.getServletContext();
        try (InputStream input = ctx.getResourceAsStream(url)) {
            if (input == null) {
                resp.sendError(404, "Not Found");
            } else {
                String file = url;
                int n = url.lastIndexOf('/');
                if (n >= 0) {
                    file = url.substring(n + 1);
                }
                String mime = ctx.getMimeType(file);
                if (mime == null) {
                    mime = "application/octet-stream";
                }
                resp.setContentType(mime);
                ServletOutputStream output = resp.getOutputStream();
                input.transferTo(output);
                output.flush();
            }
        }
    }

    /**
     * Maps one URL pattern to one controller handler method.
     */
    static class Dispatcher {

        static final Result NOT_PROCESSED = new Result(false, null);

        final Logger logger = LoggerFactory.getLogger(getClass());

        /**
         * Whether the handler belongs to a {@code @RestController}.
         */
        boolean isRest;

        /**
         * Whether the handler method is annotated with {@code @ResponseBody}.
         */
        boolean isResponseBody;

        /**
         * Whether the handler method returns {@code void}.
         */
        boolean isVoid;

        /**
         * The endpoint URL compiled to a regular expression; incoming URLs are matched against it.
         */
        Pattern urlPattern;

        /**
         * Bean instance the handler is invoked on.
         */
        Object controller;

        /**
         * The handler method.
         */
        Method handlerMethod;

        /**
         * Binding description of each handler parameter, in declaration order.
         */
        Param[] methodParameters;

        public Dispatcher(String httpMethod, boolean isRest, Object controller, Method method, String urlPattern) throws ServletException {
            this.isRest = isRest;
            this.isResponseBody = method.getAnnotation(ResponseBody.class) != null;
            this.isVoid = method.getReturnType() == void.class;
            this.urlPattern = PathUtils.compile(urlPattern);
            this.controller = controller;
            this.handlerMethod = method;
            Parameter[] params = method.getParameters();
            Annotation[][] paramsAnnos = method.getParameterAnnotations();
            this.methodParameters = new Param[params.length];
            for (int i = 0; i < params.length; i++) {
                this.methodParameters[i] = new Param(httpMethod, method, params[i], paramsAnnos[i]);
            }
            logger.atDebug().log("mapping {} to handler {}.{}", urlPattern, controller.getClass().getSimpleName(), method.getName());
            if (logger.isDebugEnabled()) {
                for (Param p : this.methodParameters) {
                    logger.debug("> parameter: {}", p);
                }
            }
        }

        /**
         * Invokes the handler if {@code url} matches this dispatcher's pattern.
         *
         * @param url      request URI
         * @param request  request
         * @param response response
         * @return {@link Result} with {@code processed == true} and the handler's return value if the
         *         URL matched, otherwise {@link #NOT_PROCESSED}
         * @throws Exception whatever the handler throws (unwrapped), or a
         *                   {@code ServerWebInputException} for missing/malformed input
         */
        Result process(String url, HttpServletRequest request, HttpServletResponse response) throws Exception {
            Matcher matcher = urlPattern.matcher(url);
            if (!matcher.matches()) {
                return NOT_PROCESSED;
            }
            Object[] arguments = new Object[this.methodParameters.length];
            for (int i = 0; i < arguments.length; i++) {
                arguments[i] = resolveArgument(this.methodParameters[i], matcher, request, response);
            }
            try {
                return new Result(true, this.handlerMethod.invoke(this.controller, arguments));
            } catch (InvocationTargetException e) {
                // Surface the handler's own exception instead of the reflection wrapper.
                Throwable cause = e.getCause();
                if (cause instanceof Exception ex) {
                    throw ex;
                }
                if (cause instanceof Error err) {
                    throw err;
                }
                throw e;
            }
        }

        /**
         * Builds the value for one handler parameter.
         */
        private Object resolveArgument(Param param, Matcher matcher, HttpServletRequest request,
                                       HttpServletResponse response) throws Exception {
            return switch (param.paramType) {
                case PATH_VARIABLE -> {
                    String raw;
                    try {
                        raw = matcher.group(param.name);
                    } catch (IllegalArgumentException e) {
                        // The compiled pattern has no group with this name.
                        throw new ServerWebInputException("路径参数 '" + param.name + "'没有找到");
                    }
                    yield convertArgument(param, raw);
                }
                case REQUEST_BODY -> {
                    BufferedReader reader = request.getReader();
                    yield JsonUtils.readJson(reader, param.classType);
                }
                case REQUEST_PARAM -> convertArgument(param, getOrDefault(request, param.name, param.defaultValue));
                case SERVLET_VARIABLE -> resolveServletArgument(param.classType, request, response);
            };
        }

        /**
         * Supplies a Servlet API object (request, response, session or servlet context).
         */
        private Object resolveServletArgument(Class<?> classType, HttpServletRequest request, HttpServletResponse response) {
            if (classType == HttpServletRequest.class) {
                return request;
            } else if (classType == HttpServletResponse.class) {
                return response;
            } else if (classType == HttpSession.class) {
                return request.getSession();
            } else if (classType == ServletContext.class) {
                return request.getServletContext();
            }
            throw new ServerErrorException("Could not determine argument type: " + classType);
        }

        /**
         * Converts client-supplied text to the parameter's type. A value that cannot be parsed is a
         * client error (400), not a server error.
         */
        private Object convertArgument(Param param, String value) {
            try {
                return convertToType(param.classType, value);
            } catch (NumberFormatException e) {
                throw new ServerWebInputException("Cannot convert parameter '" + param.name + "' to " + param.classType.getSimpleName());
            }
        }

        /**
         * Converts {@code s} to {@code classType}; supports String, boolean, and the numeric
         * primitives with their wrappers.
         *
         * @throws NumberFormatException if {@code s} is not a valid number for a numeric type
         */
        Object convertToType(Class<?> classType, String s) {
            if (classType == String.class) {
                return s;
            } else if (classType == boolean.class || classType == Boolean.class) {
                return Boolean.valueOf(s);
            } else if (classType == int.class || classType == Integer.class) {
                return Integer.valueOf(s);
            } else if (classType == long.class || classType == Long.class) {
                return Long.valueOf(s);
            } else if (classType == byte.class || classType == Byte.class) {
                return Byte.valueOf(s);
            } else if (classType == short.class || classType == Short.class) {
                return Short.valueOf(s);
            } else if (classType == float.class || classType == Float.class) {
                return Float.valueOf(s);
            } else if (classType == double.class || classType == Double.class) {
                return Double.valueOf(s);
            } else {
                throw new ServerErrorException("Could not determine argument type: " + classType);
            }
        }

        /**
         * Reads a request parameter (query string or form field).
         *
         * @param request      request
         * @param name         parameter name
         * @param defaultValue default declared on the annotation; a value equal to
         *                     {@link WebUtils#DEFAULT_PARAM_VALUE} means "no default declared"
         * @return {@link String} the parameter value, or {@code defaultValue} when absent
         * @throws ServerWebInputException if the parameter is absent and no default was declared
         * @author liuzhiyong
         * @date 2023/11/9
         */
        String getOrDefault(HttpServletRequest request, String name, String defaultValue) {
            String s = request.getParameter(name);
            if (s == null) {
                if (WebUtils.DEFAULT_PARAM_VALUE.equals(defaultValue)) {
                    throw new ServerWebInputException("Request parameter '" + name + "' not found.");
                }
                return defaultValue;
            }
            return s;
        }

    }

    /**
     * How a handler parameter gets its value.
     */
    enum ParamType {

        /**
         * Named segment of the URL path ({@code @PathVariable}).
         */
        PATH_VARIABLE,

        /**
         * Query-string or form parameter ({@code @RequestParam}).
         */
        REQUEST_PARAM,

        /**
         * JSON request body ({@code @RequestBody}).
         */
        REQUEST_BODY,

        /**
         * Servlet API object such as {@link HttpServletRequest}, supplied from the request being handled.
         */
        SERVLET_VARIABLE;
    }

    /**
     * Binding description of one handler parameter.
     */
    static class Param {

        /**
         * Parameter name as given by {@code @PathVariable} or {@code @RequestParam}.
         */
        String name;

        /**
         * How the value is obtained.
         */
        ParamType paramType;

        /**
         * Declared Java type of the parameter.
         */
        Class<?> classType;

        /**
         * Default value from {@code @RequestParam}.
         */
        String defaultValue;

        /**
         * Builds the binding description from the parameter's annotations.
         *
         * @param httpMethod  HTTP verb of the mapping, GET or POST
         * @param method      handler method (used in error messages)
         * @param parameter   the parameter
         * @param annotations annotations on the parameter
         * @throws ServletException if more than one binding annotation is present
         * @author liuzhiyong
         * @date 2023/11/9
         */
        public Param(String httpMethod, Method method, Parameter parameter, Annotation[] annotations) throws ServletException {
            PathVariable pv = ClassUtils.getAnnotation(annotations, PathVariable.class);
            RequestParam rp = ClassUtils.getAnnotation(annotations, RequestParam.class);
            RequestBody rb = ClassUtils.getAnnotation(annotations, RequestBody.class);
            // At most one binding annotation is allowed per parameter.
            int total = (pv == null ? 0 : 1) + (rp == null ? 0 : 1) + (rb == null ? 0 : 1);
            if (total > 1) {
                throw new ServletException("Annotation @PathVariable, @RequestParam and @RequestBody cannot be combined at method: " + method);
            }
            this.classType = parameter.getType();
            if (pv != null) {
                this.name = pv.value();
                this.paramType = ParamType.PATH_VARIABLE;
            } else if (rp != null) {
                this.name = rp.value();
                this.defaultValue = rp.defaultValue();
                this.paramType = ParamType.REQUEST_PARAM;
            } else if (rb != null) {
                this.paramType = ParamType.REQUEST_BODY;
            } else {
                // Unannotated parameters must be one of the supported Servlet API types.
                this.paramType = ParamType.SERVLET_VARIABLE;
                if (this.classType != HttpServletRequest.class
                        && this.classType != HttpServletResponse.class
                        && this.classType != HttpSession.class
                        && this.classType != ServletContext.class) {
                    throw new ServerErrorException("(Missing annotation?) Unsupported argument type: " + classType + " at method: " + method);
                }
            }

        }

        @Override
        public String toString() {
            return "Param{" +
                    "name='" + name + '\'' +
                    ", paramType=" + paramType +
                    ", classType=" + classType +
                    ", defaultValue='" + defaultValue + '\'' +
                    '}';
        }
    }

    /**
     * Outcome of offering a request to a {@link Dispatcher}.
     *
     * @param processed    {@code true} if the dispatcher matched and invoked its handler
     * @param returnObject the handler's return value (may be {@code null})
     */
    record Result(boolean processed, Object returnObject) {
    }

}
